library(mlr)
## Loading required package: ParamHelpers
library(iml)
library(nnet)
library(ggplot2)
library(RColorBrewer)
library(GADGET)
This is an R Markdown document. Markdown is a simple formatting syntax for authoring HTML, PDF, and MS Word documents. For more details on using R Markdown, see https://rmarkdown.rstudio.com.
When you click the **Knit** button, a document is generated that includes both the content and the output of any embedded R code chunks. You can embed an R code chunk like the one below, which simulates the XOR-style example data used throughout this analysis:
n <- 500
seed <- 123

# Simulate an XOR-style regression problem with four features drawn
# uniformly from [-1, 1]. The sign of x3 flips the slope of x1 and the sign
# of x4 flips the slope of x2; x3 also enters linearly, plus Gaussian noise
# with sd 0.3.
#
# n:    number of observations to draw.
# seed: optional RNG seed. When non-NULL, set.seed(seed) is called before
#       drawing so the data are reproducible. Note that this resets the
#       global RNG state.
# Returns a data.frame with columns x1, x2, x3, x4 and the response y.
create_xor <- function(n, seed = NULL) {
  if (!is.null(seed)) {
    set.seed(seed)
  }
  # Draw order (x2, x3, x1, x4) is kept so a given seed reproduces the same
  # data as earlier runs of this document.
  x2 <- runif(n, -1, 1)
  x3 <- runif(n, -1, 1)
  x1 <- runif(n, -1, 1)
  x4 <- runif(n, -1, 1)
  y <- ifelse(x3 > 0, 3 * x1, -3 * x1) +
    ifelse(x4 > 0, 3 * x2, -3 * x2) +
    x3 +
    rnorm(n, sd = 0.3)
  data.frame(x1, x2, x3, x4, y)
}

# Seeding with 123 here yields exactly the data previously produced by
# set.seed(123) followed by an unseeded draw.
syn.data <- create_xor(n, seed)
X <- syn.data[, setdiff(names(syn.data), "y")]
features <- colnames(X)
head(syn.data)
## x1 x2 x3 x4 y
## 1 -0.45275453 -0.4248450 -0.2927878 0.87605652 -0.5077988
## 2 0.18773387 0.5766103 -0.2671171 0.97600661 0.5875256
## 3 -0.67963037 -0.1820462 -0.4257997 -0.08736088 2.1538358
## 4 0.70686048 0.7660348 -0.8400542 -0.53877011 -5.2983926
## 5 0.69547832 0.8809346 -0.2690915 0.39097854 -0.4775255
## 6 -0.04422637 -0.9088870 -0.6439724 0.11326469 -2.9257822
task <- makeRegrTask(data = syn.data, target = "y")

# Tune the neural network: full grid over weight decay x hidden-layer size,
# evaluated with 5-fold cross-validation. tuneParams optimises the first
# measure (MSE); MAE and R^2 are reported alongside.
set.seed(123)
ps <- makeParamSet(
  makeDiscreteParam("decay", values = c(0.5, 0.1, 1e-2, 1e-3, 1e-4, 1e-5)),
  makeDiscreteParam("size", values = c(3, 5, 10, 20, 30))
)
ctrl <- makeTuneControlGrid()
rdesc <- makeResampleDesc("CV", iters = 5L)
# trace = FALSE silences nnet's per-iteration log for every CV fold and grid
# point, matching the settings of the final fit below.
res <- tuneParams(
  makeLearner("regr.nnet", maxit = 1000, trace = FALSE),
  task = task,
  resampling = rdesc,
  par.set = ps,
  control = ctrl,
  measures = list(mlr::mse, mlr::mae, mlr::rsq)
)
# Refit the neural network on all training data using the best
# hyperparameters found by tuning (res$x), then score it on a separately
# simulated test set of 10,000 rows.
set.seed(123)
lrn <- makeLearner(
  "regr.nnet",
  maxit = 1000,
  size = res$x$size,
  decay = res$x$decay,
  trace = FALSE
)
model <- mlr::train(task = task, learner = lrn)
testset <- create_xor(n = 10000, seed = 234)
pred <- predict(model, newdata = testset)$data
# Out-of-sample R^2 of the refit model.
measureRSQ(pred$truth, pred$response)
# Prediction wrapper for iml: returns the numeric response vector of the mlr
# model's predictions on `newdata`.
predict.function <- function(model, newdata) {
  predict(model, newdata = newdata)$data$response
}

# Wrap the model for iml. The wrapper defined above is passed explicitly;
# previously it was defined but never used. The feature data excludes the
# target column y.
syn.predictor <- Predictor$new(
  model,
  data = syn.data[setdiff(names(syn.data), "y")],
  y = syn.data$y,
  predict.function = predict.function
)
# ICE curves for the predictor's features, evaluated on a 20-point grid.
syn.effect <- FeatureEffects$new(syn.predictor, grid.size = 20, method = "ice")
# Grow a GADGET tree on the ICE curves in syn.effect. The call uses
# objective "SS_L2_pd" and lists all four features in Z; the printed tree
# below splits only on x3 and x4. Remaining settings: n.split = 3,
# n.quantiles = 50, min.split = 1.
syn.tree <- compute_tree(
  effect = syn.effect,
  target.feature = "y",
  testdata = syn.data,
  Z = c("x1", "x2", "x3", "x4"),
  objective = "SS_L2_pd",
  n.quantiles = 50,
  n.split = 3,
  min.split = 1
)

# Print the complete tree. Each node reports its split rule, the
# heterogeneity of every feature, and its number of instances.
extract_split_criteria(syn.tree)
## 🌳 Full Tree Structure:
## ────────────────────────────────────────
## [depth: 1 | id: 0 | intImp: 0.496 | x1.heter: 33696.83 | x2.heter: 31262.10 | x3.heter: 28324.53 | x4.heter: 27318.17 | # inst: 500]
## ✂️ x3 ≤ -0.015
## [depth: 2 | id: 1 | intImp: 0.233 | x1.heter: 848.67 | x2.heter: 15443.35 | x3.heter: 279.65 | x4.heter: 13791.66 | # inst: 241]
## ✂️ x3 ≤ -0.015 & x4 ≤ 0.031
## [depth: 3 | id: 1 | x1.heter: 406.25 | x2.heter: 474.79 | x3.heter: 111.38 | x4.heter: 96.50 | # inst: 111]
## 🌿 Leaf Node
## ✂️ x3 ≤ -0.015 & x4 > 0.031
## [depth: 3 | id: 2 | x1.heter: 426.02 | x2.heter: 292.93 | x3.heter: 166.28 | x4.heter: 258.05 | # inst: 130]
## 🌿 Leaf Node
## ✂️ x3 > -0.015
## [depth: 2 | id: 2 | intImp: 0.228 | x1.heter: 1191.62 | x2.heter: 15461.14 | x3.heter: 356.58 | x4.heter: 13423.35 | # inst: 259]
## ✂️ x3 > -0.015 & x4 ≤ 0.002
## [depth: 3 | id: 1 | x1.heter: 585.94 | x2.heter: 550.71 | x3.heter: 208.31 | x4.heter: 130.11 | # inst: 145]
## 🌿 Leaf Node
## ✂️ x3 > -0.015 & x4 > 0.002
## [depth: 3 | id: 2 | x1.heter: 584.70 | x2.heter: 598.23 | x3.heter: 145.63 | x4.heter: 137.57 | # inst: 114]
## 🌿 Leaf Node
# Extract the tree structure for a single feature of interest (here x1):
# same splits as the full tree, but only x1's heterogeneity is reported.
extract_split_criteria(syn.tree, "x1")
## Feature x1 - 🌳 Full partition tree:
## ────────────────────────────────────────
## [depth: 1 | id: 0 | intImp: 0.496 | x1.heter: 33696.83 | # inst: 500]
## ✂️ x3 ≤ -0.015
## [depth: 2 | id: 1 | intImp: 0.233 | x1.heter: 848.67 | # inst: 241]
## ✂️ x3 ≤ -0.015 & x4 ≤ 0.031
## [depth: 3 | id: 1 | x1.heter: 406.25 | # inst: 111]
## 🌿 Leaf Node
## ✂️ x3 ≤ -0.015 & x4 > 0.031
## [depth: 3 | id: 2 | x1.heter: 426.02 | # inst: 130]
## 🌿 Leaf Node
## ✂️ x3 > -0.015
## [depth: 2 | id: 2 | intImp: 0.228 | x1.heter: 1191.62 | # inst: 259]
## ✂️ x3 > -0.015 & x4 ≤ 0.002
## [depth: 3 | id: 1 | x1.heter: 585.94 | # inst: 145]
## 🌿 Leaf Node
## ✂️ x3 > -0.015 & x4 > 0.002
## [depth: 3 | id: 2 | x1.heter: 584.70 | # inst: 114]
## 🌿 Leaf Node
# Build every plot for the fitted tree in a single call; keep them in `plots`.
plots <- plot_tree(syn.tree, syn.effect, target.feature = "y")

# Visualize the splits of the tree.
plot_tree_structure(syn.tree)